SFT 流水线
目录
✨️概述
此流水线用于监督微调(SFT),提供:
- 统一的数据编码与对话模板:支持 system/user/assistant 对话格式拼接,并自动构造
labels(仅对回答部分计 loss)。 - 高效分布式训练:使用 Ray + Cluster/Worker 抽象启动分布式训练。
- 全面的性能监控:细粒度度量跟踪系统,监控性能指标,为模型训练过程提供全面的可视化和分析能力。
- 高效训练优化:支持 Sequence Packing(将多条短样本拼接成连续序列,减少 padding)。配置方法和实现原理详见
sequence packing对应文档。
✨️核心组件
主模块(SFTPipeline)
SFTPipeline(位于 roll/pipeline/sft/sft_pipeline.py)是 SFT 训练的主流程,负责:
- 加载 tokenizer。
- 加载训练数据集 与(可选)验证数据集。
- 按模板编码数据:生成
input_ids/attention_mask/labels。 - 初始化分布式训练集群(
Cluster+SFTWorker)。 - 训练循环:按 step 训练、按
eval_steps验证、按保存策略写 checkpoint、记录指标并上报 tracker。
工作器(SFTWorker)
SFTWorker(位于 roll/pipeline/sft/sft_worker.py)负责执行训练、验证与保存:
initialize():创建并初始化分布式策略(create_strategy),并加载模型。train_step():执行一次训练 step,返回训练 metrics。val_step():执行一次验证 step(前向 + loss),返回验证 metrics。do_checkpoint():保存 checkpoint,并返回保存耗时等 metrics。
配置文件(SFTConfig)
SFTConfig(定义于 roll/pipeline/sft/sft_config.py)是 SFT 流水线的配置对象(dataclass 风格),支持通过 YAML + Hydra 管理。